import os
import json
import sys
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

# read inference config file
config = "../inference/inference_wan2.1_1.3b.json"
if len(sys.argv) >= 2:
    config = sys.argv[1]
try:
    with open(config, 'r') as f:
        inference_config = json.load(f)
except OSError:
    print(f"FAILED: {config} can't open")

# model path
transformer_paths = inference_config['model']['transformer']
vae_path = inference_config['model']['vae']
text_encoder_path = inference_config['model']['text_encoder']
tokenizer_path = inference_config['model']['tokenizer']

# data path
video_path = inference_config['video_path']
image_path = inference_config['image_path']
output_path = inference_config['output_path']

if not os.path.exists(output_path):
    os.mkdir(output_path)

model_config = []
for model_path in transformer_paths:
    model_config.append(ModelConfig(path=model_path))
model_config.append(ModelConfig(path=vae_path))
model_config.append(ModelConfig(path=text_encoder_path))

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="npu",
    model_configs=model_config,
    tokenizer_config=ModelConfig(path=tokenizer_path)
)

pipe.enable_vram_management()

# Depth video -> Video
control_video = VideoData(video_path, height=480, width=832)
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走",
    vace_video=control_video,
    seed=1, tiled=True
)
save_video(video, os.path.join(output_path, "video1.mp4"), fps=15, quality=5)

# Reference image -> Video
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走",
    vace_reference_image=Image.open(image_path).resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, os.path.join(output_path, "video2.mp4"), fps=15, quality=5)

# Depth video + Reference image -> Video
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套，站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走",
    vace_video=control_video,
    vace_reference_image=Image.open(image_path).resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, os.path.join(output_path, "video3.mp4"), fps=15, quality=5)
